@@ -103,17 +103,38 @@ end
103103
104104@device_override  Base. log (x:: Float64 ) =  ccall (" extern __nv_log"  , llvmcall, Cdouble, (Cdouble,), x)
105105@device_override  Base. log (x:: Float32 ) =  ccall (" extern __nv_logf"  , llvmcall, Cfloat, (Cfloat,), x)
106+ @device_override  function  Base. log (x:: Float16 )
107+     if  compute_capability () >=  sv " 8.0" 
108+         ccall (" extern __nv_hlog"  , llvmcall, Float16, (Float16,), x)
109+     else 
110+         return  Float16 (log (Float32 (x)))
111+     end 
112+ end 
106113@device_override  FastMath. log_fast (x:: Float32 ) =  ccall (" extern __nv_fast_logf"  , llvmcall, Cfloat, (Cfloat,), x)
107114
108115@device_override  Base. log10 (x:: Float64 ) =  ccall (" extern __nv_log10"  , llvmcall, Cdouble, (Cdouble,), x)
109116@device_override  Base. log10 (x:: Float32 ) =  ccall (" extern __nv_log10f"  , llvmcall, Cfloat, (Cfloat,), x)
117+ @device_override  function  Base. log10 (x:: Float16 )
118+     if  compute_capability () >=  sv " 8.0" 
119+         ccall (" extern __nv_hlog10"  , llvmcall, Float16, (Float16,), x)
120+     else 
121+         return  Float16 (log10 (Float32 (x)))
122+     end 
123+ end 
110124@device_override  FastMath. log10_fast (x:: Float32 ) =  ccall (" extern __nv_fast_log10f"  , llvmcall, Cfloat, (Cfloat,), x)
111125
112126@device_override  Base. log1p (x:: Float64 ) =  ccall (" extern __nv_log1p"  , llvmcall, Cdouble, (Cdouble,), x)
113127@device_override  Base. log1p (x:: Float32 ) =  ccall (" extern __nv_log1pf"  , llvmcall, Cfloat, (Cfloat,), x)
114128
115129@device_override  Base. log2 (x:: Float64 ) =  ccall (" extern __nv_log2"  , llvmcall, Cdouble, (Cdouble,), x)
116130@device_override  Base. log2 (x:: Float32 ) =  ccall (" extern __nv_log2f"  , llvmcall, Cfloat, (Cfloat,), x)
131+ @device_override  function  Base. log2 (x:: Float16 )
132+     if  compute_capability () >=  sv " 8.0" 
133+         ccall (" extern __nv_hlog2"  , llvmcall, Float16, (Float16,), x)
134+     else 
135+         return  Float16 (log (Float32 (x)))
136+     end 
137+ end 
117138@device_override  FastMath. log2_fast (x:: Float32 ) =  ccall (" extern __nv_fast_log2f"  , llvmcall, Cfloat, (Cfloat,), x)
118139
119140@device_function  logb (x:: Float64 ) =  ccall (" extern __nv_logb"  , llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +148,35 @@ end
127148
128149@device_override  Base. exp (x:: Float64 ) =  ccall (" extern __nv_exp"  , llvmcall, Cdouble, (Cdouble,), x)
129150@device_override  Base. exp (x:: Float32 ) =  ccall (" extern __nv_expf"  , llvmcall, Cfloat, (Cfloat,), x)
151+ @device_override  function  Base. exp (x:: Float16 )
152+     if  compute_capability () >=  sv " 8.0" 
153+         ccall (" extern __nv_hexp"  , llvmcall, Float16, (Float16,), x)
154+     else 
155+         return  Float16 (exp (Float32 (x)))
156+     end 
157+ end 
130158@device_override  FastMath. exp_fast (x:: Float32 ) =  ccall (" extern __nv_fast_expf"  , llvmcall, Cfloat, (Cfloat,), x)
131159
132160@device_override  Base. exp2 (x:: Float64 ) =  ccall (" extern __nv_exp2"  , llvmcall, Cdouble, (Cdouble,), x)
133161@device_override  Base. exp2 (x:: Float32 ) =  ccall (" extern __nv_exp2f"  , llvmcall, Cfloat, (Cfloat,), x)
162+ @device_override  function  Base. exp2 (x:: Float16 )
163+     if  compute_capability () >=  sv " 8.0" 
164+         ccall (" extern __nv_hexp2"  , llvmcall, Float16, (Float16,), x)
165+     else 
166+         return  Float16 (exp2 (Float32 (x)))
167+     end 
168+ end 
134169@device_override  FastMath. exp2_fast (x:: Union{Float32, Float64} ) =  exp2 (x)
135- #  TODO : enable once PTX > 7.0 is supported
136- #  @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137170
138171@device_override  Base. exp10 (x:: Float64 ) =  ccall (" extern __nv_exp10"  , llvmcall, Cdouble, (Cdouble,), x)
139172@device_override  Base. exp10 (x:: Float32 ) =  ccall (" extern __nv_exp10f"  , llvmcall, Cfloat, (Cfloat,), x)
173+ @device_override  function  Base. exp10 (x:: Float16 )
174+     if  compute_capability () >=  sv " 8.0" 
175+         ccall (" extern __nv_hexp10"  , llvmcall, Float16, (Float16,), x)
176+     else 
177+         return  Float16 (exp10 (Float32 (x)))
178+     end 
179+ end 
140180@device_override  FastMath. exp10_fast (x:: Float32 ) =  ccall (" extern __nv_fast_exp10f"  , llvmcall, Cfloat, (Cfloat,), x)
141181
142182@device_override  Base. expm1 (x:: Float64 ) =  ccall (" extern __nv_expm1"  , llvmcall, Cdouble, (Cdouble,), x)
204244
205245@device_override  Base. isnan (x:: Float64 ) =  (ccall (" extern __nv_isnand"  , llvmcall, Int32, (Cdouble,), x)) !=  0 
206246@device_override  Base. isnan (x:: Float32 ) =  (ccall (" extern __nv_isnanf"  , llvmcall, Int32, (Cfloat,), x)) !=  0 
247+ @device_override  function  Base. isnan (x:: Float16 )
248+     if  compute_capability () >=  sv " 8.0" 
249+         return  (ccall (" extern __nv_hisnan"  , llvmcall, Int32, (Float16,), x)) !=  0 
250+     else 
251+         return  isnan (Float32 (x))
252+     end 
253+ end 
207254
208255@device_function  nearbyint (x:: Float64 ) =  ccall (" extern __nv_nearbyint"  , llvmcall, Cdouble, (Cdouble,), x)
209256@device_function  nearbyint (x:: Float32 ) =  ccall (" extern __nv_nearbyintf"  , llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +270,26 @@ end
223270@device_override  Base. abs (x:: Int32 ) =    ccall (" extern __nv_abs"  , llvmcall, Int32, (Int32,), x)
224271@device_override  Base. abs (f:: Float64 ) =  ccall (" extern __nv_fabs"  , llvmcall, Cdouble, (Cdouble,), f)
225272@device_override  Base. abs (f:: Float32 ) =  ccall (" extern __nv_fabsf"  , llvmcall, Cfloat, (Cfloat,), f)
226- #  TODO : enable once PTX > 7.0 is supported
227- #  @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
273+ @device_override  function  Base. abs (f:: Float16 )
274+     if  compute_capability () >=  sv " 8.0" 
275+         ccall (" extern __nv_habs"  , llvmcall, Float16, (Float16,), f)
276+     else 
277+         return  Float16 (abs (Float32 (f)))
278+     end 
279+ end 
228280@device_override  Base. abs (x:: Int64 ) =    ccall (" extern __nv_llabs"  , llvmcall, Int64, (Int64,), x)
229281
230282# # roots and powers
231283
232284@device_override  Base. sqrt (x:: Float64 ) =  ccall (" extern __nv_sqrt"  , llvmcall, Cdouble, (Cdouble,), x)
233285@device_override  Base. sqrt (x:: Float32 ) =  ccall (" extern __nv_sqrtf"  , llvmcall, Cfloat, (Cfloat,), x)
286+ @device_override  function  Base. sqrt (x:: Float16 )
287+     if  compute_capability () >=  sv " 8.0" 
288+         ccall (" extern __nv_hsqrt"  , llvmcall, Float16, (Float16,), x)
289+     else 
290+         return  Float16 (sqrt (Float32 (x)))
291+     end 
292+ end 
234293@device_override  FastMath. sqrt_fast (x:: Union{Float32, Float64} ) =  sqrt (x)
235294
236295@device_function  rsqrt (x:: Float64 ) =  ccall (" extern __nv_rsqrt"  , llvmcall, Cdouble, (Cdouble,), x)
295354#  JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
296355# @device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
297356# @device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
357+ @device_override  @inline  function  Base. min (x:: Float16 , y:: Float16 )
358+     if  compute_capability () >=  sv " 8.0" 
359+         return  ccall (" extern __nv_hmin"  , llvmcall, Float16, (Float16, Float16), x, y)
360+     else 
361+         return  Float16 (min (Float32 (x), Float32 (y)))
362+     end 
363+ end 
298364@device_override  @inline  function  Base. min (x:: Float32 , y:: Float32 )
299365    if  @static  LLVM. version () <  v " 14"   ?  false  :  (compute_capability () >=  sv " 8.0"  )
300366        #  LLVM 14+ can do the right thing, but only on sm_80+
321387#  JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
322388# @device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
323389# @device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
390+ @device_override  @inline  function  Base. max (x:: Float16 , y:: Float16 )
391+     if  compute_capability () >=  sv " 8.0" 
392+         return  ccall (" extern __nv_hmax"  , llvmcall, Float16, (Float16, Float16), x, y)
393+     else 
394+         return  Float16 (max (Float32 (x), Float32 (y)))
395+     end 
396+ end 
324397@device_override  @inline  function  Base. max (x:: Float32 , y:: Float32 )
325398    if  @static  LLVM. version () <  v " 14"   ?  false  :  (compute_capability () >=  sv " 8.0"  )
326399        #  LLVM 14+ can do the right thing, but only on sm_80+
0 commit comments