Description
PR #960 added the outdims function to Flux.jl, which makes it easier to infer the output dimensions of a neural network model (especially one that uses convolutional layers). However, outdims returns incorrect results for chained layers, and the existing test cases do not catch this flaw. For example, try the following code:
using Flux
using Flux: outdims
all_dense_model = Chain(Dense(2, 10), Dense(10, 4))
outdims(all_dense_model, 2) # will return (10,) while it should be (4,)
I'm using Flux v0.10.3 in Julia v1.3.0 under macOS
Currently, outdims is implemented as:
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)
However, foldl with ∘ composes the functions in the wrong order: the first element ends up outermost, so it is applied last instead of first. For example, try the following code:
f1(x) = 2x; f2(x) = exp(x); f3(x) = x^3
const x0 = 0.5
# we want f3(f2(f1(x))), i.e. f1 is applied first
f_true = f3∘f2∘f1
f_true(x0) # 20.085536923187664
# however
f = foldl(∘, (f1, f2, f3)) # returns (f1∘f2)∘f3, i.e. f1(f2(f3(x))): f3 is applied first!
f(x0) # 2.2662969061336526
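For comparison, here is a minimal sketch of two ways to get the intended order using only Base. The names fs, apply_in_order, and composed are illustrative and do not come from Flux.

```julia
# Same toy functions as in the example above.
f1(x) = 2x
f2(x) = exp(x)
f3(x) = x^3
fs = (f1, f2, f3)   # intended order: f1 first, f3 last

# Option 1: thread the value through each function in turn.
apply_in_order(fs, x) = foldl((acc, f) -> f(acc), fs; init = x)

# Option 2: reverse before composing, so that f1 ends up innermost.
composed = foldl(∘, reverse(fs))   # (f3 ∘ f2) ∘ f1

# Both agree with the hand-written composition.
apply_in_order(fs, 0.5) ≈ (f3 ∘ f2 ∘ f1)(0.5)   # true
composed(0.5) ≈ (f3 ∘ f2 ∘ f1)(0.5)             # true
```

Option 1 avoids building an intermediate composed function, which may be the simpler choice when only the final value is needed.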
I suggest that we resolve this problem with the following modifications:
# edit
using Base: tail # thanks to @darsnack for pointing this out
outdims(::Tuple{}, isize) = isize # base case: no layers left, so the recursion stops here
outdims(t::Tuple, isize) = outdims(tail(t), outdims(first(t), isize))
outdims(c::Chain, isize) = outdims(c.layers, isize)
and add the following test cases:
@testset "dense layer" begin
    X = randn(Float32, 5, 10)
    D0, D1, D2 = 5, 100, 25
    dense1 = Dense(D0, D1, relu)
    dense2 = Dense(D1, D2)
    dense_chain = Chain(dense1, dense2)
    @test outdims(dense1, D0) == (D1,)
    @test first(outdims(dense1, D0)) == size(dense1(X), 1)
    @test first(outdims(dense_chain, D0)) == size(dense_chain(X), 1)
end
@testset "conv layer" begin
    X = randn(Float32, 28, 28, 1, 1)
    D0, D1 = 3, 5
    S, P = 3, 1
    # channels must line up: X has 1 channel and conv2 expects 3 input channels
    conv1_stride = Conv((D0, D0), 1=>3, stride=S, pad=P)
    conv2 = Conv((D1, D1), 3=>16)
    conv_chain = Chain(conv1_stride, conv2)
    @test typeof(outdims(conv1_stride, (28, 28))) <: Tuple
    @test length(outdims(conv1_stride, (28, 28))) == 2
    @test outdims(conv1_stride, (28, 28)) == size(conv1_stride(X))[1:2]
    @test outdims(conv_chain, (28, 28)) == size(conv_chain(X))[1:2]
end
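The tuple recursion proposed above can also be tried without Flux. The sketch below uses hypothetical stand-ins (ToyDense, ToyChain, toy_outdims) that are not part of Flux and only mimic the size bookkeeping. Note that it includes an empty-tuple method, which is what terminates the recursion.

```julia
using Base: tail

# Hypothetical stand-in for a dense layer; only the sizes matter here.
struct ToyDense
    nin::Int
    nout::Int
end

# Hypothetical stand-in for Chain: a wrapper around a tuple of layers.
struct ToyChain{T<:Tuple}
    layers::T
end

# A dense layer's output size depends only on the layer, not on isize.
toy_outdims(l::ToyDense, isize) = (l.nout,)

# Base case: no layers left, return the size accumulated so far.
toy_outdims(::Tuple{}, isize) = isize

# Recursive case: apply the first layer, then recurse on the rest.
toy_outdims(t::Tuple, isize) = toy_outdims(tail(t), toy_outdims(first(t), isize))

toy_outdims(c::ToyChain, isize) = toy_outdims(c.layers, isize)

model = ToyChain((ToyDense(2, 10), ToyDense(10, 4)))
toy_outdims(model, 2)   # (4,)
```

Because the recursion peels off one tuple element per call, the layers are visited first to last, which is the order a chain applies them in.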
Maybe we could also consider extending outdims to more complex chained models, e.g. models that contain plain Julia functions such as x->reshape(x, :, 4), and also writing outdims methods for recurrent layers.
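One possible fallback for plain functions, sketched here purely as an assumption, is to evaluate the function on a dummy array and read off the size. The name size_by_eval is hypothetical and not part of Flux. Unlike outdims, it returns the full array size rather than only the spatial or feature dimensions, and it costs a real function call.

```julia
# Hypothetical fallback, not part of Flux: infer the output size of an
# arbitrary function by running it on a zero-filled array of size isize.
size_by_eval(f, isize::Tuple) = size(f(zeros(Float32, isize...)))

# 8 * 2 = 16 elements reshaped into 4 columns gives a 4x4 array.
size_by_eval(x -> reshape(x, :, 4), (8, 2))   # (4, 4)
```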
I'll collect these into a PR if you'd like, @MikeInnes @baggepinnen @darsnack.