@@ -4,7 +4,6 @@ export @groupreduce, @warp_groupreduce
4
4
@groupreduce op val neutral [groupsize]
5
5
6
6
Perform group reduction of `val` using `op`.
7
- If backend supports warp reduction, it will use it instead of thread reduction.
8
7
9
8
# Arguments
10
9
@@ -27,13 +26,6 @@ macro groupreduce(op, val, groupsize)
27
26
:(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val ($ (esc (groupsize)))))
28
27
end
29
28
30
- macro warp_groupreduce (op, val, neutral)
31
- :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
32
- end
33
- macro warp_groupreduce (op, val, neutral, groupsize)
34
- :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
35
- end
36
-
37
29
function __thread_groupreduce (__ctx__, op, val:: T , :: Val{groupsize} ) where {T, groupsize}
38
30
storage = @localmem T groupsize
39
31
61
53
62
54
# Warp groupreduce.
63
55
64
- # NOTE: Backends should implement these two device functions (with `@device_override`).
56
+ """
57
+ @warp_groupreduce op val neutral [groupsize]
58
+
59
+ Perform group reduction of `val` using `op`.
60
+ Each warp within a workgroup performs its own reduction using the [`shfl_down`](@ref) intrinsic,
61
+ followed by a final reduction over the results of the individual warp reductions.
62
+
63
+ !!! note
64
+
65
+ Use [`supports_warp_reduction`](@ref) to query whether a given backend supports warp reduction.
66
+ """
67
+ macro warp_groupreduce (op, val, neutral)
68
+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
69
+ end
70
+ macro warp_groupreduce (op, val, neutral, groupsize)
71
+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
72
+ end
73
+
74
+ """
75
+ shfl_down(val::T, offset::Integer)::T where T
76
+
77
+ Read `val` from the lane whose id is higher than the calling lane's by `offset`.
78
+ """
65
79
function shfl_down end
66
80
supports_warp_reduction () = false
67
- # Host-variant.
81
+
82
+ """
83
+ supports_warp_reduction(::Backend)
84
+
85
+ Query whether the given backend supports the [`shfl_down`](@ref) intrinsic and thus warp reduction.
86
+ """
68
87
supports_warp_reduction (:: Backend ) = false
69
88
70
89
# Assume a warp is 32 lanes wide.
0 commit comments