outdims function doesn't work properly for chained layers #1086

@HamletWantToCode

PR #960 added an outdims function to Flux.jl, which makes it easier to infer the output dimensions of a neural network model (especially one with convolutional layers). However, outdims returns incorrect results for chained layers, and the existing test cases do not catch this flaw. For example, try the following code:

using Flux
using Flux: outdims

all_dense_model = Chain(Dense(2, 10), Dense(10, 4))
outdims(all_dense_model, 2)    # will return (10,) while it should be (4,)

I'm using Flux v0.10.3 with Julia v1.3.0 on macOS.

Currently, outdims is implemented as:

outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))(isize)

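However, foldl(∘, ...) composes the functions in the wrong order. Since (f ∘ g)(x) == f(g(x)), folding left over (f1, f2, f3) yields a function that applies f3 first and f1 last, so the layers are evaluated in reverse. For example, try the following code: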

f1(x) = 2x; f2(x) = exp(x); f3(x) = x^3
const x0 = 0.5
# we want to compute f3(f2(f1(x)))
f_true = f3 ∘ f2 ∘ f1
f_true(x0)   # 20.085536923187664
# however
f = foldl(∘, (f1, f2, f3))   # returns (f1∘f2)∘f3, which applies f3 first and f1 last!
f(x0) # 2.2662969061336526
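To see how this order reversal produces (10,) for the two-layer chain at the top, here is a minimal sketch using a hypothetical ToyDense type and toy_outdims function (neither is part of Flux). It assumes, as the (10,) result suggests, that Dense's outdims returns the layer's output width regardless of the input size. Reversing the tuple before folding restores the intended order.

```julia
# Hypothetical stand-in for a Dense layer: only the output width matters here,
# mirroring a Dense outdims that ignores the input size.
struct ToyDense
    n_in::Int
    n_out::Int
end
toy_outdims(l::ToyDense, isize) = (l.n_out,)

layers = (ToyDense(2, 10), ToyDense(10, 4))
size_fns = map(l -> (x -> toy_outdims(l, x)), layers)

# Same pattern as the current Chain method: foldl(∘, ...) builds
# f_layer1 ∘ f_layer2, so the LAST layer is applied first.
buggy = foldl(∘, size_fns)(2)            # (10,)

# Reversing the tuple first makes the first layer run first.
fixed = foldl(∘, reverse(size_fns))(2)   # (4,)
```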

I suggest resolving this problem with the following modification:

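using Base: tail # thanks to @darsnack for pointing this out

# Thread the size through the layers in order: first layer first.
outdims(::Tuple{}, isize) = isize   # base case: no layers left, size is unchanged
outdims(t::Tuple, isize) = outdims(tail(t), outdims(first(t), isize))
outdims(c::Chain, isize) = outdims(c.layers, isize)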
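An equivalent, non-recursive way to thread the size through the layers is a left fold with an explicit initial value, which also handles an empty chain without a separate method. The sketch below uses hypothetical toy layer types (ToyDense, ToyConv) and a hypothetical chain_outdims rather than Flux's own types; the Flux analogue would presumably be foldl((sz, l) -> outdims(l, sz), c.layers; init = isize), but that is an untested suggestion.

```julia
# Hypothetical toy layers standing in for Flux's Dense and a 2-D conv.
struct ToyDense
    n_in::Int
    n_out::Int
end
struct ToyConv   # square kernel, stride 1, no padding
    k::Int
end

toy_outdims(l::ToyDense, isize) = (l.n_out,)
toy_outdims(l::ToyConv, isize) = map(n -> n - l.k + 1, isize)

# Left fold: the accumulator `sz` is the output size of all layers seen so far,
# and each step feeds it into the next layer. An empty tuple returns `isize`.
chain_outdims(layers::Tuple, isize) =
    foldl((sz, l) -> toy_outdims(l, sz), layers; init = isize)

chain_outdims((ToyDense(2, 10), ToyDense(10, 4)), 2)   # (4,)
chain_outdims((ToyConv(3), ToyConv(5)), (28, 28))       # (22, 22)
```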

and adding the following test cases:

@testset "dense layer" begin
    X = randn(Float32, 5, 10)   # 5 features, batch of 10
    D0, D1, D2 = 5, 100, 25
    dense1 = Dense(D0, D1, relu)
    dense2 = Dense(D1, D2)
    dense_chain = Chain(dense1, dense2)
    @test outdims(dense1, D0) == (D1,)
    @test first(outdims(dense1, D0)) == size(dense1(X), 1)
    @test first(outdims(dense_chain, D0)) == size(dense_chain(X), 1)
end

@testset "conv layer" begin
    X = randn(Float32, 28, 28, 3, 1)   # 28x28 input with 3 channels, batch of 1
    D0, D1 = 3, 5
    S, P = 3, 1
    # channel counts must line up through the chain: 3=>16 feeds 16=>32
    conv1_stride = Conv((D0, D0), 3=>16, stride=S, pad=P)
    conv2 = Conv((D1, D1), 16=>32)
    conv_chain = Chain(conv1_stride, conv2)
    @test outdims(conv1_stride, (28, 28)) isa Tuple
    @test length(outdims(conv1_stride, (28, 28))) == 2
    @test outdims(conv1_stride, (28, 28)) == size(conv1_stride(X))[1:2]
    @test outdims(conv_chain, (28, 28)) == size(conv_chain(X))[1:2]
end
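The expected spatial sizes in the conv test follow from the usual convolution output-size formula, out = floor((in + 2*pad - dilation*(k - 1) - 1) / stride) + 1 for each spatial dimension. A small helper (conv_out is a hypothetical name, not a Flux function) makes the arithmetic explicit, assuming symmetric padding:

```julia
# Output length of one spatial dimension of a convolution,
# assuming the same padding `pad` on both sides.
conv_out(n, k; stride = 1, pad = 0, dilation = 1) =
    div(n + 2pad - dilation * (k - 1) - 1, stride) + 1

h1 = conv_out(28, 3; stride = 3, pad = 1)   # 10: first (strided) conv
h2 = conv_out(h1, 5)                         # 6: second conv, 5x5, no padding
```

So outdims(conv_chain, (28, 28)) should come out as (6, 6) if the chain is handled correctly.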

We could also consider extending outdims to more complex chained models, e.g. models that contain plain Julia functions such as x -> reshape(x, :, 4), and writing outdims methods for recurrent layers.
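For plain Julia functions, one possible approach is to infer the output size by running the function on a zero-filled dummy array of the given input size. This is only a sketch under that assumption: function_outdims is a hypothetical name, it performs one real evaluation of the function, and it expects the full array size as a tuple, which differs from the conventions of the existing Dense (feature count) and Conv (spatial size) methods.

```julia
# Hypothetical fallback: infer a function's output size by evaluating it
# on a zero-filled dummy array with the given input size.
function_outdims(f, isize::Tuple) = size(f(zeros(Float32, isize...)))

function_outdims(x -> reshape(x, :, 4), (8, 2))        # (4, 4)
function_outdims(x -> permutedims(x, (2, 1)), (3, 5))  # (5, 3)
```

This keeps the logic generic at the cost of allocating a dummy input, so it suits cheap shape-only operations better than expensive layers.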

I'll collect these into a PR if you'd like, @MikeInnes @baggepinnen @darsnack.
