Skip to content
21 changes: 21 additions & 0 deletions docs/src/models/basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,24 @@ Flux.@functor Affine
```

This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

## Utility functions

Flux provides some utility functions to help you generate models in an automated fashion.

`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
Currently limited to the following layers:
- `Chain`
- `Dense`
- `Conv`
- `Diagonal`
- `Maxout`
- `ConvTranspose`
- `DepthwiseConv`
- `CrossCor`
- `MaxPool`
- `MeanPool`

```@docs
outdims
```
27 changes: 27 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end

"""
outdims(c::Chain, isize)

Calculate the output dimensions given the input dimensions, `isize`.

```julia
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
outdims(m, (10, 10)) == (6, 6)
```
"""
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
Expand Down Expand Up @@ -116,6 +127,19 @@ end
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
outdims(l::Dense, isize)

Calculate the output dimensions given the input dimensions, `isize`.

```julia
m = Dense(10, 5)
outdims(m, (5, 2)) == (5,)
outdims(m, (10,)) == (5,)
```
"""
outdims(l::Dense, isize) = (size(l.W)[1],)

"""
Diagonal(in::Integer)

Expand Down Expand Up @@ -145,6 +169,7 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end

outdims(l::Diagonal, isize) = (length(l.α),)

"""
Maxout(over)
Expand Down Expand Up @@ -193,6 +218,8 @@ function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

outdims(l::Maxout, isize) = outdims(first(l.over), isize)

"""
SkipConnection(layers, connection)

Expand Down
35 changes: 34 additions & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
using NNlib: conv, ∇conv_data, depthwiseconv
using NNlib: conv, ∇conv_data, depthwiseconv, output_size

# pad dims of x with dims of y until ndims(x) == ndims(y)
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)

_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -68,6 +73,21 @@ end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
outdims(l::Conv, isize::Tuple)

Calculate the output dimensions given the input dimensions, `isize`.
Batch size and channel size are ignored as per `NNlib.jl`.

```julia
m = Conv((3, 3), 3 => 16)
outdims(m, (10, 10)) == (8, 8)
outdims(m, (10, 10, 1, 3)) == (8, 8)
```
"""
outdims(l::Conv, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
Expand Down Expand Up @@ -140,6 +160,9 @@ end

(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)

"""
DepthwiseConv(size, in=>out)
DepthwiseConv(size, in=>out, relu)
Expand Down Expand Up @@ -204,6 +227,9 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::DepthwiseConv, isize) =
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Expand Down Expand Up @@ -275,6 +301,9 @@ end
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::CrossCor, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
MaxPool(k)

Expand Down Expand Up @@ -304,6 +333,8 @@ function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end

outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))

"""
MeanPool(k)

Expand Down Expand Up @@ -331,3 +362,5 @@ end
function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end

outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
15 changes: 15 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,19 @@ import Flux: activations
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end
end

@testset "output dimensions" begin
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
@test Flux.outdims(m, (10, 10)) == (6, 6)

m = Dense(10, 5)
@test Flux.outdims(m, (5, 2)) == (5,)
@test Flux.outdims(m, (10,)) == (5,)

m = Flux.Diagonal(10)
@test Flux.outdims(m, (10,)) == (10,)

m = Maxout(() -> Conv((3, 3), 3 => 16), 2)
@test Flux.outdims(m, (10, 10)) == (8, 8)
end
end
52 changes: 52 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,55 @@ end
true
end
end

@testset "conv output dimensions" begin
m = Conv((3, 3), 3 => 16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to have some tests for different padding, stride and dilation

@test Flux.outdims(m, (10, 10)) == (8, 8)
m = Conv((3, 3), 3 => 16; stride = 2)
@test Flux.outdims(m, (5, 5)) == (2, 2)
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)
m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test Flux.outdims(m, (5, 5)) == (4, 4)

m = ConvTranspose((3, 3), 3 => 16)
@test Flux.outdims(m, (8, 8)) == (10, 10)
m = ConvTranspose((3, 3), 3 => 16; stride = 2)
@test Flux.outdims(m, (2, 2)) == (5, 5)
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)
m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test Flux.outdims(m, (4, 4)) == (5, 5)

m = DepthwiseConv((3, 3), 3 => 6)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2)
@test Flux.outdims(m, (5, 5)) == (2, 2)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)
m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2)
@test Flux.outdims(m, (5, 5)) == (4, 4)

m = CrossCor((3, 3), 3 => 16)
@test Flux.outdims(m, (10, 10)) == (8, 8)
m = CrossCor((3, 3), 3 => 16; stride = 2)
@test Flux.outdims(m, (5, 5)) == (2, 2)
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)
m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2)
@test Flux.outdims(m, (5, 5)) == (4, 4)

m = MaxPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
m = MaxPool((2, 2); stride = 1)
@test Flux.outdims(m, (5, 5)) == (4, 4)
m = MaxPool((2, 2); stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)

m = MeanPool((2, 2))
@test Flux.outdims(m, (10, 10)) == (5, 5)
m = MeanPool((2, 2); stride = 1)
@test Flux.outdims(m, (5, 5)) == (4, 4)
m = MeanPool((2, 2); stride = 2, pad = 3)
@test Flux.outdims(m, (5, 5)) == (5, 5)
end