@@ -143,63 +143,6 @@ def sub_group_barrier_impl(context, builder, sig, args):
143143 return _void_value
144144
145145
146- def insert_and_call_atomic_fn (
147- context , builder , sig , fn_type , dtype , ptr , val , addrspace
148- ):
149- ll_p = None
150- name = ""
151- if dtype .name == "float32" :
152- ll_val = llvmir .FloatType ()
153- ll_p = ll_val .as_pointer ()
154- if fn_type == "add" :
155- name = "numba_dpex_atomic_add_f32"
156- elif fn_type == "sub" :
157- name = "numba_dpex_atomic_sub_f32"
158- else :
159- raise TypeError ("Operation type is not supported %s" % (fn_type ))
160- elif dtype .name == "float64" :
161- if True :
162- ll_val = llvmir .DoubleType ()
163- ll_p = ll_val .as_pointer ()
164- if fn_type == "add" :
165- name = "numba_dpex_atomic_add_f64"
166- elif fn_type == "sub" :
167- name = "numba_dpex_atomic_sub_f64"
168- else :
169- raise TypeError (
170- "Operation type is not supported %s" % (fn_type )
171- )
172- else :
173- raise TypeError (
174- "Atomic operation is not supported for type %s" % (dtype .name )
175- )
176-
177- if addrspace == address_space .LOCAL :
178- name = name + "_local"
179- else :
180- name = name + "_global"
181-
182- assert ll_p is not None
183- assert name != ""
184- ll_p .addrspace = address_space .GENERIC
185-
186- mod = builder .module
187- if sig .return_type == types .void :
188- llretty = llvmir .VoidType ()
189- else :
190- llretty = context .get_value_type (sig .return_type )
191-
192- llargs = [ll_p , context .get_value_type (sig .args [2 ])]
193- fnty = llvmir .FunctionType (llretty , llargs )
194-
195- fn = cgutils .get_or_insert_function (mod , fnty , name )
196- fn .calling_convention = kernel_target .CC_SPIR_FUNC
197-
198- generic_ptr = context .addrspacecast (builder , ptr , address_space .GENERIC )
199-
200- return builder .call (fn , [generic_ptr , val ])
201-
202-
203146def native_atomic_add (context , builder , sig , args ):
204147 aryty , indty , valty = sig .args
205148 ary , inds , val = args
@@ -286,20 +229,15 @@ def native_atomic_add(context, builder, sig, args):
286229@lower (stubs .atomic .add , types .Array , types .UniTuple , types .Any )
287230@lower (stubs .atomic .add , types .Array , types .Tuple , types .Any )
288231def atomic_add_tuple (context , builder , sig , args ):
289- device_type = dpctl .get_current_queue ().sycl_device .device_type
290232 dtype = sig .args [0 ].dtype
291233
292- if dtype == types .float32 or dtype == types .float64 :
293- if (
294- device_type == dpctl .device_type .gpu
295- and config .NATIVE_FP_ATOMICS == 1
296- ):
297- return native_atomic_add (context , builder , sig , args )
298- else :
299- # Currently, DPCPP only supports native floating point
300- # atomics for GPUs.
301- return atomic_add (context , builder , sig , args , "add" )
302- elif dtype == types .int32 or dtype == types .int64 :
234+ # TODO: do we need this check, or should we just use native_atomic_add for everything?
235+ if (
236+ dtype == types .float32
237+ or dtype == types .float64
238+ or dtype == types .int32
239+ or dtype == types .int64
240+ ):
303241 return native_atomic_add (context , builder , sig , args )
304242 else :
305243 raise TypeError ("Atomic operation on unsupported type %s" % dtype )
@@ -337,83 +275,19 @@ def atomic_sub_wrapper(context, builder, sig, args):
337275@lower (stubs .atomic .sub , types .Array , types .UniTuple , types .Any )
338276@lower (stubs .atomic .sub , types .Array , types .Tuple , types .Any )
339277def atomic_sub_tuple (context , builder , sig , args ):
340- device_type = dpctl .get_current_queue ().sycl_device .device_type
341278 dtype = sig .args [0 ].dtype
342279
343- if dtype == types .float32 or dtype == types .float64 :
344- if (
345- device_type == dpctl .device_type .gpu
346- and config .NATIVE_FP_ATOMICS == 1
347- ):
348- return atomic_sub_wrapper (context , builder , sig , args )
349- else :
350- # Currently, DPCPP only supports native floating point
351- # atomics for GPUs.
352- return atomic_add (context , builder , sig , args , "sub" )
353- elif dtype == types .int32 or dtype == types .int64 :
280+ if (
281+ dtype == types .float32
282+ or dtype == types .float64
283+ or dtype == types .int32
284+ or dtype == types .int64
285+ ):
354286 return atomic_sub_wrapper (context , builder , sig , args )
355287 else :
356288 raise TypeError ("Atomic operation on unsupported type %s" % dtype )
357289
358290
359- def atomic_add (context , builder , sig , args , name ):
360- from .atomics import atomic_support_present
361-
362- if atomic_support_present ():
363- context .extra_compile_options [kernel_target .LINK_ATOMIC ] = True
364- aryty , indty , valty = sig .args
365- ary , inds , val = args
366- dtype = aryty .dtype
367-
368- if indty == types .intp :
369- indices = [inds ] # just a single integer
370- indty = [indty ]
371- else :
372- indices = cgutils .unpack_tuple (builder , inds , count = len (indty ))
373- indices = [
374- context .cast (builder , i , t , types .intp )
375- for t , i in zip (indty , indices )
376- ]
377-
378- if dtype != valty :
379- raise TypeError ("expecting %s but got %s" % (dtype , valty ))
380-
381- if aryty .ndim != len (indty ):
382- raise TypeError (
383- "indexing %d-D array with %d-D index" % (aryty .ndim , len (indty ))
384- )
385-
386- lary = context .make_array (aryty )(context , builder , ary )
387- ptr = cgutils .get_item_pointer (context , builder , aryty , lary , indices )
388-
389- if isinstance (aryty , Array ) and aryty .addrspace == address_space .LOCAL :
390- return insert_and_call_atomic_fn (
391- context ,
392- builder ,
393- sig ,
394- name ,
395- dtype ,
396- ptr ,
397- val ,
398- address_space .LOCAL ,
399- )
400- else :
401- return insert_and_call_atomic_fn (
402- context ,
403- builder ,
404- sig ,
405- name ,
406- dtype ,
407- ptr ,
408- val ,
409- address_space .GLOBAL ,
410- )
411- else :
412- raise ImportError (
413- "Atomic support is not present, can not perform atomic_add"
414- )
415-
416-
417291@lower (stubs .private .array , types .IntegerLiteral , types .Any )
418292def dpex_private_array_integer (context , builder , sig , args ):
419293 length = sig .args [0 ].literal_value
0 commit comments