diff --git a/Project.toml b/Project.toml
index c254b25..35f4e26 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
 JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
diff --git a/src/extensions/Flux.jl b/src/extensions/Flux.jl
index ade30f2..a49981d 100644
--- a/src/extensions/Flux.jl
+++ b/src/extensions/Flux.jl
@@ -1,8 +1,9 @@
-export seed_glorot_normal, seed_glorot_uniform
+export seed_glorot_normal, seed_glorot_uniform, seed_orthogonal
 
 import Flux: glorot_uniform, glorot_normal
 
 using Random
+using LinearAlgebra
 
 glorot_uniform(rng::AbstractRNG, dims...) =
   (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(Flux.nfan(dims...)))
@@ -13,3 +14,54 @@ seed_glorot_uniform(; seed = nothing) =
   (dims...) -> glorot_uniform(MersenneTwister(seed), dims...)
 seed_glorot_normal(; seed = nothing) =
   (dims...) -> glorot_normal(MersenneTwister(seed), dims...)
+
+# References:
+# https://github.com/FluxML/Flux.jl/pull/1171/
+# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/Orthogonal
+
+"""
+    orthogonal_matrix(rng::AbstractRNG, nrow, ncol)
+
+Return an `nrow`×`ncol` `Float32` matrix built from the QR factorisation of a
+Gaussian random matrix drawn from `rng`. Its columns are orthonormal when
+`nrow >= ncol`; otherwise its rows are.
+"""
+function orthogonal_matrix(rng::AbstractRNG, nrow, ncol)
+  # Factorise the tall (max, min) shape; transpose at the end when nrow < ncol.
+  shape = reverse(minmax(nrow, ncol))
+  a = randn(rng, Float32, shape)
+  q, r = qr(a)
+  # Scale each column of the thin Q by the sign of the matching diagonal entry of R.
+  # A zero entry maps to +1 (`sign` would give 0 and wipe out the column), and
+  # `Diagonal` avoids building the dense square matrix that `diagm` allocates.
+  signs = ifelse.(diag(r) .< 0, -1f0, 1f0)
+  q = Matrix(q) * Diagonal(signs)
+  return nrow < ncol ? permutedims(q) : q
+end
+
+"""
+    orthogonal([rng::AbstractRNG,] d1, rest_dims...)
+
+Return a `Float32` array of size `(d1, rest_dims...)`, obtained by reshaping
+`orthogonal_matrix(rng, d1, prod(rest_dims))`. Without `rng`, `Random.GLOBAL_RNG`
+is used.
+"""
+function orthogonal(rng::AbstractRNG, d1, rest_dims...)
+  # `prod` of an empty tuple is 1, so a single dimension is accepted as well;
+  # splatting into `*` has no zero-argument method.
+  m = orthogonal_matrix(rng, d1, prod(rest_dims))
+  return reshape(m, d1, rest_dims...)
+end
+
+# Requiring at least one dimension keeps `orthogonal()` from dispatching back to
+# this method forever.
+orthogonal(d1::Integer, rest_dims...) = orthogonal(Random.GLOBAL_RNG, d1, rest_dims...)
+
+"""
+    seed_orthogonal(; seed = nothing)
+
+Return an initialiser `(dims...) -> orthogonal(MersenneTwister(seed), dims...)`.
+A new `MersenneTwister(seed)` is constructed on every call of that initialiser.
+"""
+seed_orthogonal(; seed = nothing) =
+  (dims...) -> orthogonal(MersenneTwister(seed), dims...)