
Commit a532b00 (parent 806b0ef)

add within_gradient

3 files changed: 21 additions & 1 deletion


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.9"
+version = "0.8.10"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/utils.jl

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,18 @@
+"""
+    within_gradient(x) --> Bool
+
+Returns `false` except when used inside a `gradient` call, when it returns `true`.
+Useful for Flux regularisation layers which behave differently during training and inference.
+
+Works with any ChainRules-based differentiation package, in which case `x` is ignored.
+But Tracker.jl overloads `within_gradient(x::TrackedArray)`, thus for widest use you should
+pass it an array whose gradient is of interest.
+"""
+within_gradient(x) = false
+
+ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())
+
+
 """
     safe_div(x, y)
 
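
The Tracker.jl path mentioned in the docstring works by dispatch rather than by an rrule. A minimal sketch of what such an overload could look like, assuming Tracker's TrackedArray wrapper type (illustrative only, not code from this commit):

# Illustrative sketch, not part of this commit: Tracker.jl wraps arrays
# being differentiated in TrackedArray, so a method dispatching on that
# type answers "am I inside a gradient?" without any ChainRules machinery.
import Tracker, NNlib
NNlib.within_gradient(x::Tracker.TrackedArray) = true

This is why the docstring recommends passing an array whose gradient is of interest: under Tracker, a plain `Array` argument would still hit the `false` fallback.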

test/utils.jl

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,8 @@
+@testset "within_gradient" begin
+    @test NNlib.within_gradient([1.0]) === false
+    @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)
+end
+
 @testset "maximum_dims" begin
     ind1 = [1,2,3,4,5,6]
     @test NNlib.maximum_dims(ind1) == (6,)
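
The second test shows the mechanism end to end: inside `gradient`, the rrule makes `within_gradient(x)` evaluate to `true` (which multiplies as `1`), so the function being differentiated reduces to `x` and its derivative is `1.0`. Downstream, a regularisation layer might use it like the following sketch; `toy_dropout` is a hypothetical example, not part of NNlib:

# Hypothetical layer built on within_gradient: it masks its input only
# while a gradient is being taken, and is the identity at inference time.
using NNlib, Zygote

function toy_dropout(x; p = 0.5f0)
    if NNlib.within_gradient(x)
        mask = rand(Float32, size(x)) .> p   # keep each entry with probability 1 - p
        return x .* mask ./ (1 - p)          # rescale to preserve the expected value
    else
        return x                             # inference path: unchanged
    end
end

toy_dropout(ones(Float32, 3))                                # returns the input unchanged
Zygote.gradient(x -> sum(toy_dropout(x)), ones(Float32, 3))  # mask is applied here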
