Skip to content

Commit 826a3e8

Browse files
committed
use in softmax too
1 parent 2a8f5a9 commit 826a3e8

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

src/softmax.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
6969
end
7070

7171
function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
72-
dx = if within_grad()
72+
dx = if within_gradient(y)
7373
tmp = dy .* y
7474
tmp .- y .* sum(tmp; dims)
7575
else
@@ -88,9 +88,6 @@ function rrule(::typeof(softmax), x; dims = 1)
8888
return y, softmax_pullback
8989
end
9090

91-
within_grad() = false
92-
rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
93-
9491
fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))
9592

9693
"""

0 commit comments

Comments
 (0)