Skip to content

Commit e601347

Browse files
committed
Test in-place plans
1 parent e137ae3 commit e601347

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ end
161161

162162
# plans
163163
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
164-
y = P * x
164+
y = P * x
165+
if Base.mightalias(y, x)
166+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
167+
end
165168
Δy = P * Δx
166169
return y, Δy
167170
end
168171
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
169172
y = P * x
173+
if Base.mightalias(y, x)
174+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
175+
end
170176
project_x = ChainRulesCore.ProjectTo(x)
171177
Pt = P'
172178
function mul_plan_pullback(ȳ)
@@ -178,11 +184,17 @@ end
178184

179185
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
180186
y = P * x
187+
if Base.mightalias(y, x)
188+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
189+
end
181190
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
182191
return y, Δy
183192
end
184193
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
185194
y = P * x
195+
if Base.mightalias(y, x)
196+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
197+
end
186198
Pt = P'
187199
scale = P.scale
188200
project_x = ChainRulesCore.ProjectTo(x)

test/runtests.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,27 @@ end
6868
@test fftdims(P) == dims
6969
end
7070

71+
# in-place plan
72+
P = plan_fft!(x, dims)
73+
@test eltype(P) === ComplexF64
74+
xc64 = ComplexF64.(x)
75+
@test P * xc64 fftw_fft
76+
@test xc64 fftw_fft
77+
7178
fftw_bfft = complex.(size(x, dims) .* x)
7279
@test AbstractFFTs.bfft(y, dims) fftw_bfft
7380
P = plan_bfft(x, dims)
7481
@test P * y fftw_bfft
7582
@test P \ (P * y) y
7683
@test fftdims(P) == dims
7784

85+
# in-place plan
86+
P = plan_bfft!(x, dims)
87+
@test eltype(P) === ComplexF64
88+
yc64 = ComplexF64.(y)
89+
@test P * yc64 fftw_bfft
90+
@test yc64 fftw_bfft
91+
7892
fftw_ifft = complex.(x)
7993
@test AbstractFFTs.ifft(y, dims) fftw_ifft
8094
# test plan_ifft and also inv and plan_inv of plan_fft, which should all give
@@ -86,6 +100,13 @@ end
86100
@test fftdims(P) == dims
87101
end
88102

103+
# in-place plan
104+
P = plan_ifft!(x, dims)
105+
@test eltype(P) === ComplexF64
106+
yc64 = ComplexF64.(y)
107+
@test P * yc64 fftw_ifft
108+
@test yc64 fftw_ifft
109+
89110
# real FFT
90111
fftw_rfft = fftw_fft[
91112
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
@@ -361,7 +382,8 @@ end
361382
for x_shape in ((2,), (2, 3), (3, 4, 5))
362383
N = length(x_shape)
363384
x = randn(x_shape)
364-
complex_x = randn(ComplexF64, x_shape)
385+
complex_x = randn(ComplexF64, x_shape)
386+
Δ = (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesTestUtils.rand_tangent(complex_x))
365387
for dims in unique((1, 1:N, N))
366388
# fft, ifft, bfft
367389
for f in (fft, ifft, bfft)
@@ -370,11 +392,14 @@ end
370392
test_frule(f, complex_x, dims)
371393
test_rrule(f, complex_x, dims)
372394
end
373-
for pf in (plan_fft, plan_ifft, plan_bfft)
395+
for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!))
374396
test_frule(*, pf(x, dims), x)
375397
test_rrule(*, pf(x, dims), x)
376398
test_frule(*, pf(complex_x, dims), complex_x)
377399
test_rrule(*, pf(complex_x, dims), complex_x)
400+
401+
@test_throws ArgumentError ChainRulesCore.frule(Δ, *, pf!(complex_x, dims), complex_x)
402+
@test_throws ArgumentError ChainRulesCore.rrule(*, pf!(complex_x, dims), complex_x)
378403
end
379404

380405
# rfft
@@ -392,10 +417,10 @@ end
392417
test_rrule(f, complex_x, d, dims)
393418
end
394419
end
395-
for pf in (plan_irfft, plan_brfft)
420+
for pf in (plan_irfft, plan_brfft)
396421
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
397422
test_frule(*, pf(complex_x, d, dims), complex_x)
398-
test_rrule(*, pf(complex_x, d, dims), complex_x)
423+
test_rrule(*, pf(complex_x, d, dims), complex_x)
399424
end
400425
end
401426
end

test/testplans.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,25 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray)
232232

233233
return y
234234
end
235+
236+
# In-place plans
237+
# (simple wrapper of out-of-place plans that does not support inverses)
238+
struct InplaceTestPlan{T,P<:Plan{T}} <: Plan{T}
239+
plan::P
240+
end
241+
242+
Base.size(p::InplaceTestPlan) = size(p.plan)
243+
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
244+
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
245+
246+
function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
247+
return InplaceTestPlan(plan_fft(x, region; kwargs...))
248+
end
249+
function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...)
250+
return InplaceTestPlan(plan_bfft(x, region; kwargs...))
251+
end
252+
253+
function LinearAlgebra.mul!(y::AbstractArray, p::InplaceTestPlan, x::AbstractArray)
254+
return mul!(y, p.plan, x)
255+
end
256+
Base.:*(p::InplaceTestPlan, x::AbstractArray) = copyto!(x, p.plan * x)

0 commit comments

Comments
 (0)