|
5 | 5 | from llvmlite import ir as llvmir |
6 | 6 | from numba.core import datamodel, types |
7 | 7 | from numba.core.datamodel.models import OpaqueModel, PrimitiveModel, StructModel |
8 | | -from numba.core.extending import register_model |
9 | 8 |
|
10 | 9 | from numba_dpex.core.exceptions import UnreachableError |
11 | 10 | from numba_dpex.core.types.kernel_api.atomic_ref import AtomicRefType |
@@ -316,7 +315,7 @@ def __init__(self, dmm, fe_type): |
316 | 315 | super().__init__(dmm, fe_type, members) |
317 | 316 |
|
318 | 317 |
|
319 | | -def _init_data_model_manager() -> datamodel.DataModelManager: |
| 318 | +def _init_kernel_data_model_manager() -> datamodel.DataModelManager: |
320 | 319 | """Initializes a data model manager used by the SPRIVTarget. |
321 | 320 |
|
322 | 321 | SPIRV kernel functions for certain types of devices require an explicit |
@@ -370,43 +369,50 @@ def _init_data_model_manager() -> datamodel.DataModelManager: |
370 | 369 | return dmm |
371 | 370 |
|
372 | 371 |
|
373 | | -dpex_data_model_manager = _init_data_model_manager() |
| 372 | +def _init_dpjit_data_model_manager() -> datamodel.DataModelManager: |
| 373 | + # TODO: copy manager |
| 374 | + dmm = datamodel.default_manager |
374 | 375 |
|
| 376 | + # Register the USMNdArray type to USMArrayHostModel in numba's default data |
| 377 | + # model manager |
| 378 | + dmm.register(USMNdArray, USMArrayHostModel) |
375 | 379 |
|
376 | | -# Register the USMNdArray type to USMArrayDeviceModel in numba's default data |
377 | | -# model manager |
378 | | -register_model(USMNdArray)(USMArrayHostModel) |
| 380 | + # Register the DpnpNdArray type to USMArrayHostModel in numba's default data |
| 381 | + # model manager |
| 382 | + dmm.register(DpnpNdArray, USMArrayHostModel) |
379 | 383 |
|
380 | | -# Register the DpnpNdArray type to USMArrayHostModel in numba's default data |
381 | | -# model manager |
382 | | -register_model(DpnpNdArray)(USMArrayHostModel) |
| 384 | + # Register the DpctlSyclQueue type |
| 385 | + dmm.register(DpctlSyclQueue, SyclQueueModel) |
383 | 386 |
|
384 | | -# Register the DpctlSyclQueue type |
385 | | -register_model(DpctlSyclQueue)(SyclQueueModel) |
| 387 | + # Register the DpctlSyclEvent type |
| 388 | + dmm.register(DpctlSyclEvent, SyclEventModel) |
386 | 389 |
|
387 | | -# Register the DpctlSyclEvent type |
388 | | -register_model(DpctlSyclEvent)(SyclEventModel) |
| 390 | + # Register the RangeType type |
| 391 | + dmm.register(RangeType, RangeModel) |
389 | 392 |
|
390 | | -# Register the RangeType type |
391 | | -register_model(RangeType)(RangeModel) |
| 393 | + # Register the NdRangeType type |
| 394 | + dmm.register(NdRangeType, NdRangeModel) |
392 | 395 |
|
393 | | -# Register the NdRangeType type |
394 | | -register_model(NdRangeType)(NdRangeModel) |
| 396 | + # Register the GroupType type |
| 397 | + dmm.register(GroupType, EmptyStructModel) |
395 | 398 |
|
396 | | -# Register the GroupType type |
397 | | -register_model(GroupType)(EmptyStructModel) |
| 399 | + # Register the ItemType type |
| 400 | + dmm.register(ItemType, EmptyStructModel) |
398 | 401 |
|
399 | | -# Register the ItemType type |
400 | | -register_model(ItemType)(EmptyStructModel) |
| 402 | + # Register the NdItemType type |
| 403 | + dmm.register(NdItemType, EmptyStructModel) |
| 404 | + |
| 405 | + # Register the MDLocalAccessorType type |
| 406 | + dmm.register(DpctlMDLocalAccessorType, DpctlMDLocalAccessorModel) |
401 | 407 |
|
402 | | -# Register the NdItemType type |
403 | | -register_model(NdItemType)(EmptyStructModel) |
| 408 | + # Register the LocalAccessorType type |
| 409 | + dmm.register(LocalAccessorType, LocalAccessorModel) |
404 | 410 |
|
405 | | -# Register the MDLocalAccessorType type |
406 | | -register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel) |
| 411 | + # Register the KernelDispatcherType type |
| 412 | + dmm.register(KernelDispatcherType, OpaqueModel) |
| 413 | + |
| 414 | + return dmm |
407 | 415 |
|
408 | | -# Register the LocalAccessorType type |
409 | | -register_model(LocalAccessorType)(LocalAccessorModel) |
410 | 416 |
|
411 | | -# Register the KernelDispatcherType type |
412 | | -register_model(KernelDispatcherType)(OpaqueModel) |
| 417 | +dpex_data_model_manager = _init_kernel_data_model_manager() |
| 418 | +dpjit_data_model_manager = _init_dpjit_data_model_manager() |
0 commit comments