1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: collections:: HashSet ;
19- use std:: sync:: Arc ;
20-
21- use pyo3:: exceptions:: PyKeyError ;
22- use pyo3:: prelude:: * ;
23-
24- use crate :: errors:: { PyDataFusionError , PyDataFusionResult } ;
25- use crate :: utils:: wait_for_future;
18+ use crate :: dataset:: Dataset ;
19+ use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
20+ use crate :: utils:: { validate_pycapsule, wait_for_future} ;
21+ use async_trait:: async_trait;
22+ use datafusion:: common:: DataFusionError ;
2623use datafusion:: {
2724 arrow:: pyarrow:: ToPyArrow ,
2825 catalog:: { CatalogProvider , SchemaProvider } ,
2926 datasource:: { TableProvider , TableType } ,
3027} ;
28+ use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
29+ use pyo3:: exceptions:: PyKeyError ;
30+ use pyo3:: prelude:: * ;
31+ use pyo3:: types:: PyCapsule ;
32+ use std:: any:: Any ;
33+ use std:: collections:: HashSet ;
34+ use std:: sync:: Arc ;
3135
3236#[ pyclass( name = "Catalog" , module = "datafusion" , subclass) ]
3337pub struct PyCatalog {
@@ -50,8 +54,8 @@ impl PyCatalog {
5054 }
5155}
5256
53- impl PyDatabase {
54- pub fn new ( database : Arc < dyn SchemaProvider > ) -> Self {
57+ impl From < Arc < dyn SchemaProvider > > for PyDatabase {
58+ fn from ( database : Arc < dyn SchemaProvider > ) -> Self {
5559 Self { database }
5660 }
5761}
@@ -75,7 +79,7 @@ impl PyCatalog {
7579 #[ pyo3( signature = ( name="public" ) ) ]
7680 fn database ( & self , name : & str ) -> PyResult < PyDatabase > {
7781 match self . catalog . schema ( name) {
78- Some ( database) => Ok ( PyDatabase :: new ( database) ) ,
82+ Some ( database) => Ok ( database. into ( ) ) ,
7983 None => Err ( PyKeyError :: new_err ( format ! (
8084 "Database with name {name} doesn't exist."
8185 ) ) ) ,
@@ -92,6 +96,13 @@ impl PyCatalog {
9296
9397#[ pymethods]
9498impl PyDatabase {
99+ #[ new]
100+ fn new ( schema_provider : PyObject ) -> Self {
101+ let schema_provider =
102+ Arc :: new ( RustWrappedPySchemaProvider :: new ( schema_provider) ) as Arc < dyn SchemaProvider > ;
103+ schema_provider. into ( )
104+ }
105+
95106 fn names ( & self ) -> HashSet < String > {
96107 self . database . table_names ( ) . into_iter ( ) . collect ( )
97108 }
@@ -145,3 +156,133 @@ impl PyTable {
145156 // fn has_exact_statistics
146157 // fn supports_filter_pushdown
147158}
159+
160+ #[ derive( Debug ) ]
161+ struct RustWrappedPySchemaProvider {
162+ schema_provider : PyObject ,
163+ owner_name : Option < String > ,
164+ }
165+
166+ impl RustWrappedPySchemaProvider {
167+ fn new ( schema_provider : PyObject ) -> Self {
168+ let owner_name = Python :: with_gil ( |py| {
169+ schema_provider
170+ . bind ( py)
171+ . getattr ( "owner_name" )
172+ . ok ( )
173+ . map ( |name| name. to_string ( ) )
174+ } ) ;
175+
176+ Self {
177+ schema_provider,
178+ owner_name,
179+ }
180+ }
181+
182+ fn table_inner ( & self , name : & str ) -> PyResult < Option < Arc < dyn TableProvider > > > {
183+ Python :: with_gil ( |py| {
184+ let provider = self . schema_provider . bind ( py) ;
185+ let py_table_method = provider. getattr ( "table" ) ?;
186+
187+ let py_table = py_table_method. call ( ( name, ) , None ) ?;
188+ if py_table. is_none ( ) {
189+ return Ok ( None ) ;
190+ }
191+
192+ if py_table. hasattr ( "__datafusion_table_provider__" ) ? {
193+ let capsule = provider. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
194+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
195+ validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
196+
197+ let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
198+ let provider: ForeignTableProvider = provider. into ( ) ;
199+
200+ Ok ( Some ( Arc :: new ( provider) as Arc < dyn TableProvider > ) )
201+ } else {
202+ let ds = Dataset :: new ( & py_table, py) . map_err ( py_datafusion_err) ?;
203+
204+ Ok ( Some ( Arc :: new ( ds) as Arc < dyn TableProvider > ) )
205+ }
206+ } )
207+ }
208+ }
209+
210+ #[ async_trait]
211+ impl SchemaProvider for RustWrappedPySchemaProvider {
212+ fn owner_name ( & self ) -> Option < & str > {
213+ self . owner_name . as_deref ( )
214+ }
215+
216+ fn as_any ( & self ) -> & dyn Any {
217+ self
218+ }
219+
220+ fn table_names ( & self ) -> Vec < String > {
221+ Python :: with_gil ( |py| {
222+ let provider = self . schema_provider . bind ( py) ;
223+ provider
224+ . getattr ( "table_names" )
225+ . and_then ( |names| names. extract :: < Vec < String > > ( ) )
226+ . unwrap_or_default ( )
227+ } )
228+ }
229+
230+ async fn table (
231+ & self ,
232+ name : & str ,
233+ ) -> datafusion:: common:: Result < Option < Arc < dyn TableProvider > > , DataFusionError > {
234+ self . table_inner ( name) . map_err ( to_datafusion_err)
235+ }
236+
237+ fn register_table (
238+ & self ,
239+ name : String ,
240+ table : Arc < dyn TableProvider > ,
241+ ) -> datafusion:: common:: Result < Option < Arc < dyn TableProvider > > > {
242+ let py_table = PyTable :: new ( table) ;
243+ Python :: with_gil ( |py| {
244+ let provider = self . schema_provider . bind ( py) ;
245+ let _ = provider
246+ . call_method1 ( "register_table" , ( name, py_table) )
247+ . map_err ( to_datafusion_err) ?;
248+ // Since the definition of `register_table` says that an error
249+ // will be returned if the table already exists, there is no
250+ // case where we want to return a table provider as output.
251+ Ok ( None )
252+ } )
253+ }
254+
255+ fn deregister_table (
256+ & self ,
257+ name : & str ,
258+ ) -> datafusion:: common:: Result < Option < Arc < dyn TableProvider > > > {
259+ Python :: with_gil ( |py| {
260+ let provider = self . schema_provider . bind ( py) ;
261+ let table = provider
262+ . call_method1 ( "deregister_table" , ( name, ) )
263+ . map_err ( to_datafusion_err) ?;
264+ if table. is_none ( ) {
265+ return Ok ( None ) ;
266+ }
267+
268+ // If we can turn this table provider into a `Dataset`, return it.
269+ // Otherwise, return None.
270+ let dataset = match Dataset :: new ( & table, py) {
271+ Ok ( dataset) => Some ( Arc :: new ( dataset) as Arc < dyn TableProvider > ) ,
272+ Err ( _) => None ,
273+ } ;
274+
275+ Ok ( dataset)
276+ } )
277+ }
278+
279+ fn table_exist ( & self , name : & str ) -> bool {
280+ Python :: with_gil ( |py| {
281+ let provider = self . schema_provider . bind ( py) ;
282+ provider
283+ . call_method1 ( "table_exist" , ( name, ) )
284+ . and_then ( |pyobj| pyobj. extract ( ) )
285+ . unwrap_or ( false )
286+ } )
287+ }
288+ }
0 commit comments