1515# specific language governing permissions and limitations
1616# under the License.
1717
18+ import datafusion as dfn
1819import pyarrow as pa
20+ import pyarrow .dataset as ds
1921import pytest
22+ from datafusion import SessionContext , Table
2023
2124
2225# Note we take in `database` as a variable even though we don't use
@@ -27,7 +30,7 @@ def test_basic(ctx, database):
2730 ctx .catalog ("non-existent" )
2831
2932 default = ctx .catalog ()
30- assert default .names () == [ "public" ]
33+ assert default .names () == { "public" }
3134
3235 for db in [default .database ("public" ), default .database ()]:
3336 assert db .names () == {"csv1" , "csv" , "csv2" }
@@ -41,3 +44,100 @@ def test_basic(ctx, database):
4144 pa .field ("float" , pa .float64 (), nullable = True ),
4245 ]
4346 )
47+
48+
49+ class CustomTableProvider :
50+ def __init__ (self ):
51+ pass
52+
53+
54+ def create_dataset () -> pa .dataset .Dataset :
55+ batch = pa .RecordBatch .from_arrays (
56+ [pa .array ([1 , 2 , 3 ]), pa .array ([4 , 5 , 6 ])],
57+ names = ["a" , "b" ],
58+ )
59+ return ds .dataset ([batch ])
60+
61+
62+ class CustomSchemaProvider :
63+ def __init__ (self ):
64+ self .tables = {"table1" : create_dataset ()}
65+
66+ def table_names (self ) -> set [str ]:
67+ return set (self .tables .keys ())
68+
69+ def register_table (self , name : str , table : Table ):
70+ self .tables [name ] = table
71+
72+ def deregister_table (self , name , cascade : bool = True ):
73+ del self .tables [name ]
74+
75+
76+ class CustomCatalogProvider :
77+ def __init__ (self ):
78+ self .schemas = {"my_schema" : CustomSchemaProvider ()}
79+
80+ def schema_names (self ) -> set [str ]:
81+ return set (self .schemas .keys ())
82+
83+ def schema (self , name : str ):
84+ return self .schemas [name ]
85+
86+ def register_schema (self , name : str , schema : dfn .catalog .Schema ):
87+ self .schemas [name ] = schema
88+
89+ def deregister_schema (self , name , cascade : bool ):
90+ del self .schemas [name ]
91+
92+
93+ def test_python_catalog_provider (ctx : SessionContext ):
94+ ctx .register_catalog_provider ("my_catalog" , CustomCatalogProvider ())
95+
96+ # Check the default catalog provider
97+ assert ctx .catalog ("datafusion" ).names () == {"public" }
98+
99+ my_catalog = ctx .catalog ("my_catalog" )
100+ assert my_catalog .names () == {"my_schema" }
101+
102+ my_catalog .register_schema ("second_schema" , CustomSchemaProvider ())
103+ assert my_catalog .schema_names () == {"my_schema" , "second_schema" }
104+
105+ my_catalog .deregister_schema ("my_schema" )
106+ assert my_catalog .schema_names () == {"second_schema" }
107+
108+
109+ def test_python_schema_provider (ctx : SessionContext ):
110+ catalog = ctx .catalog ()
111+
112+ catalog .deregister_schema ("public" )
113+
114+ catalog .register_schema ("test_schema1" , CustomSchemaProvider ())
115+ assert catalog .names () == {"test_schema1" }
116+
117+ catalog .register_schema ("test_schema2" , CustomSchemaProvider ())
118+ catalog .deregister_schema ("test_schema1" )
119+ assert catalog .names () == {"test_schema2" }
120+
121+
122+ def test_python_table_provider (ctx : SessionContext ):
123+ catalog = ctx .catalog ()
124+
125+ catalog .register_schema ("custom_schema" , CustomSchemaProvider ())
126+ schema = catalog .schema ("custom_schema" )
127+
128+ assert schema .table_names () == {"table1" }
129+
130+ schema .deregister_table ("table1" )
131+ schema .register_table ("table2" , create_dataset ())
132+ assert schema .table_names () == {"table2" }
133+
134+ # Use the default schema instead of our custom schema
135+
136+ schema = catalog .schema ()
137+
138+ schema .register_table ("table3" , create_dataset ())
139+ assert schema .table_names () == {"table3" }
140+
141+ schema .deregister_table ("table3" )
142+ schema .register_table ("table4" , create_dataset ())
143+ assert schema .table_names () == {"table4" }
0 commit comments